Notebook metadata changed
xxxxxxxxxx
38
 
1
{
2
  "kernelspec": {
3
    "display_name": "Python 3 (ipykernel)",
4
    "language": "python",
5
    "name": "python3"
6
  },
7
  "language_info": {
8
    "codemirror_mode": {
9
​
⇛⇚
xxxxxxxxxx
38
 
1
{
2
  "kernelspec": {
3
    "display_name": "Python [conda env:root] *",
4
    "language": "python",
5
    "name": "conda-root-py"
6
  },
7
  "language_info": {
8
​
38
}
xxxxxxxxxx
3
 
1
## TODO: 
2
- Normalize targets after clamping
3
    - Helps get training MSE down to ~0 when training on the first two samples
Metadata changed
xxxxxxxxxx

Notebook by Paul Scotti with code adapted from Aidan Dempster (https://github.com/Veldrovive/open_clip)

In particular, could somebody please try out the various networks Aidan shared (https://github.com/Veldrovive/open_clip/blob/main/src/open_clip/model.py), which include more complex architectures like transformers, as well as architectures that handle both 2D and 3D voxels.

I also have a DistributedDataParallel version of this notebook for anyone who might want to use this with multi-gpu on Slurm (just ask me for it).

In [1]:
xxxxxxxxxx
1
 
1
!nvidia-smi
Metadata changed
xxxxxxxxxx
Wed Nov 16 16:59:13 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:01:00.0 Off |                  Off |
| 30%   52C    P5    85W / 300W |      2MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
application/vnd.jupyter.stdout

Import packages & functions¶

In [ ]:
In [2]:
xxxxxxxxxx
2
 
1
# You will need to download files from huggingface and change the respective paths to those files
2
# https://huggingface.co/datasets/pscotti/naturalscenesdataset/tree/main
In [3]:
xxxxxxxxxx
2
 
1
#!pip install "git+https://github.com/openai/CLIP.git@main#egg=clip"
2
#!pip install git+https://github.com/openai/CLIP.git
Metadata changed
xxxxxxxxxx
In [4]:
xxxxxxxxxx
1
 
1
#!pip install info-nce-pytorch
Metadata changed
xxxxxxxxxx
In [51]:
xxxxxxxxxx
23
 
1
import os
2
import sys
3
import math
4
import numpy as np
5
import pandas as pd
6
from matplotlib import pyplot as plt
7
import seaborn as sns
8
sns.set(font_scale=1.0)
9
import torch
10
from torch import nn
11
import torchvision
12
from torchvision import transforms
13
from tqdm import tqdm
14
import PIL
15
from datetime import datetime
16
import h5py
17
import webdataset as wds
18
from info_nce import InfoNCE
19
import clip
20
import time
21
from collections import OrderedDict
22
from glob import glob
23
from PIL import Image
Metadata changed
xxxxxxxxxx
In [1]:
In [45]:
xxxxxxxxxx
57
 
1
import os
2
import sys
3
import math
4
import numpy as np
5
from matplotlib import pyplot as plt
6
import torch
7
from torch import nn
8
import torchvision
9
from torchvision import transforms
10
from tqdm import tqdm
11
import PIL
12
from datetime import datetime
13
import h5py
14
import webdataset as wds
15
from info_nce import InfoNCE
16
import clip
17
​
18
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
print(device)
20
​
21
mean=np.array([0.48145466, 0.4578275, 0.40821073])
22
std=np.array([0.26862954, 0.26130258, 0.27577711])
23
denorm = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
24
​
25
def np_to_Image(x):
26
    return PIL.Image.fromarray((x.transpose(1, 2, 0)*127.5+128).clip(0,255).astype('uint8'))
27
def torch_to_Image(x,device=device):
28
    x = denorm(x)
29
    return transforms.ToPILImage()(x)
30
def Image_to_torch(x):
31
    return (transforms.ToTensor()(x[0])[:3].unsqueeze(0)-.5)/.5
32
def pairwise_cosine_similarity(A, B, dim=1, eps=1e-8):
33
    #https://stackoverflow.com/questions/67199317/pytorch-cosine-similarity-nxn-elements
34
​
37
    denominator = torch.max(torch.sqrt(torch.outer(A_l2, B_l2)), torch.tensor(eps))
38
    return torch.div(numerator, denominator)
39
def batchwise_cosine_similarity(Z,B):
40
    # https://www.h4pz.co/blog/2021/4/2/batch-cosine-similarity-in-pytorch-or-numpy-jax-cupy-etc
41
    B = B.T
42
    Z_norm = torch.linalg.norm(Z, dim=1, keepdim=True)  # Size (n, 1).
43
    B_norm = torch.linalg.norm(B, dim=0, keepdim=True)  # Size (1, b).
44
    cosine_similarity = ((Z @ B) / (Z_norm @ B_norm)).T
45
    return cosine_similarity
46
def get_non_diagonals(a):
47
    a = torch.triu(a,diagonal=1)+torch.tril(a,diagonal=-1)
48
    # make diagonals -1
49
    a=a.fill_diagonal_(-1)
50
    return a
51
def topk(similarities,labels,k=5):
52
    if k > similarities.shape[0]:
53
        k = similarities.shape[0]
54
    topsum=0
55
    for i in range(k):
56
        topsum += torch.sum(torch.argsort(similarities,axis=1)[:,-(i+1)] == labels)/len(labels)
57
    return topsum
⇛⇚
xxxxxxxxxx
55
 
1
# Pick GPU when available; all tensors/models below are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Per-channel image normalization statistics — presumably the OpenAI CLIP
# preprocessing mean/std; verify against the CLIP preprocess transform.
mean = np.array([0.48145466, 0.4578275, 0.40821073])
std = np.array([0.26862954, 0.26130258, 0.27577711])

# Inverse of Normalize(mean, std): maps a normalized tensor back to [0, 1].
denorm = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
7
​
8
def np_to_Image(x):
    """Convert a CHW numpy array to a PIL image.

    Assumes `x` is roughly in [-1, 1] (the *127.5 + 128 affine mapping
    implies this) — TODO confirm with callers.
    """
    hwc = x.transpose(1, 2, 0)                       # CHW -> HWC for PIL
    pixels = (hwc * 127.5 + 128).clip(0, 255)        # [-1,1] -> [0,255]
    return PIL.Image.fromarray(pixels.astype('uint8'))
10
​
11
def torch_to_Image(x, device=device):
    """Undo the CLIP-style normalization on tensor `x` and return a PIL image.

    NOTE(review): the `device` parameter is never used in the body — kept
    for interface compatibility.
    """
    unnormalized = denorm(x)
    return transforms.ToPILImage()(unnormalized)
14
​
15
def Image_to_torch(x):
    """Convert the first image in `x` to a (1, 3, H, W) tensor scaled to [-1, 1]."""
    rgb = transforms.ToTensor()(x[0])[:3]   # keep first 3 channels (drop alpha)
    batched = rgb.unsqueeze(0)              # add batch dimension
    return (batched - .5) / .5              # [0,1] -> [-1,1]
17
​
18
def pairwise_cosine_similarity(A, B, dim=1, eps=1e-8):
19
    #https://stackoverflow.com/questions/67199317/pytorch-cosine-similarity-nxn-elements
20
​
23
    denominator = torch.max(torch.sqrt(torch.outer(A_l2, B_l2)), torch.tensor(eps))
24
    return torch.div(numerator, denominator)
25
​
26
def batchwise_cosine_similarity(Z, B):
    """All-pairs cosine similarity between rows of Z and rows of B.

    Returns a (rows-of-B, rows-of-Z) matrix (note the trailing transpose).
    # https://www.h4pz.co/blog/2021/4/2/batch-cosine-similarity-in-pytorch-or-numpy-jax-cupy-etc
    """
    Bt = B.T
    z_lengths = torch.linalg.norm(Z, dim=1, keepdim=True)   # (n, 1)
    b_lengths = torch.linalg.norm(Bt, dim=0, keepdim=True)  # (1, b)
    # Dot products normalized by the outer product of row lengths.
    return ((Z @ Bt) / (z_lengths @ b_lengths)).T
33
​
34
def get_non_diagonals(a):
    """Return a copy of square matrix `a` with its diagonal replaced by -1.

    Off-diagonal entries are preserved; the input tensor is not modified
    (triu + tril allocate a new tensor before the in-place diagonal fill).
    """
    off_diagonal = torch.triu(a, diagonal=1) + torch.tril(a, diagonal=-1)
    # make diagonals -1
    return off_diagonal.fill_diagonal_(-1)
39
​
40
def topk(similarities, labels, k=5):
    """Top-k retrieval accuracy: fraction of rows whose `labels` entry is
    among the k highest-similarity columns of `similarities`.

    `k` is clamped to the number of rows.
    """
    if k > similarities.shape[0]:
        k = similarities.shape[0]
    # Ascending sort, so the best match of each row sits in the last column.
    ranking = torch.argsort(similarities, axis=1)
    score = 0
    for rank in range(k):
        score = score + torch.sum(ranking[:, -(rank + 1)] == labels) / len(labels)
    return score
47
​
48
def get_preprocs():
    """Build the (voxel, image) preprocessing pipelines.

    Voxels: to-tensor plus NaN -> 0 replacement.
    Images: resize to 224x224 then normalize with the module-level
    `mean`/`std` constants.
    """
    vox_pipeline = transforms.Compose([transforms.ToTensor(), torch.nan_to_num])
    img_pipeline = transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.Normalize(mean=mean, std=std),
    ])
    return vox_pipeline, img_pipeline
Outputs unchanged
cuda
application/vnd.jupyter.stdout

Which pretrained model are you using for voxel alignment to embedding space?¶

In [2]:
In [8]:
xxxxxxxxxx
7
 
1
# Embedding source used as the voxel-alignment target.
model_name = 'clip_image_vit' # CLIP ViT-L/14 image embeddings

# Alternatives (uncomment exactly one to switch targets):
# model_name = 'clip_text_vit'     # CLIP ViT-L/14 text embeddings
# model_name = 'clip_image_resnet' # CLIP basic ResNet image embeddings

print(f"Using model: {model_name}")
Outputs unchanged
Using model: clip_image_vit
application/vnd.jupyter.stdout
In [3]:
In [9]:
xxxxxxxxxx
51
 
1
​
11
# dont want to train model
12
model.eval()
13
# dont need to calculate gradients
14
for param in model.parameters():
15
    param.requires_grad = False
16
​
17
if model_name=='clip_text_vit':
18
    f = h5py.File('/scratch/gpfs/KNORMAN/nsdgeneral_hdf5/COCO_73k_subj_indices.hdf5', 'r')
19
    subj01_order = f['subj01'][:]
20
​
46
            image_features = model.encode_image(image.to(device))
47
            if "vit" in model_name: # I think this is the clamping used by Lin Sprague Singh preprint
48
                image_features = torch.clamp(image_features,-1.5,1.5) 
49
        return image_features     
50
    
51
#print(model)
⇛⇚
xxxxxxxxxx
54
 
1
​
11
# dont want to train model
12
model.eval()
13
​
14
# dont need to calculate gradients
15
for param in model.parameters():
16
    param.requires_grad = False
17
​
18
if model_name == 'clip_text_vit':
19
    f = h5py.File('/scratch/gpfs/KNORMAN/nsdgeneral_hdf5/COCO_73k_subj_indices.hdf5', 'r')
20
    subj01_order = f['subj01'][:]
21
​
47
            image_features = model.encode_image(image.to(device))
48
            if "vit" in model_name: # I think this is the clamping used by Lin Sprague Singh preprint
49
                image_features = torch.clamp(image_features, -1.5, 1.5)
50
                # normalize after clipping per the paper
51
                image_features = nn.functional.normalize(image_features, dim=-1)
52
        return image_features
53
    
54
#print(model)

Load data¶

NSD webdatasets for subjects 1, 2, and 3 are publicly available here:

https://huggingface.co/datasets/pscotti/naturalscenesdataset/tree/main/webdataset

In [4]:
In [10]:
xxxxxxxxxx
3
 
1
# Use large batches and the complete training dataset? When False, a small
# subset and small batches are used downstream for quick iteration.
full_training = True
print('full_training', full_training)
Outputs unchanged
full_training True
application/vnd.jupyter.stdout
In [11]:
xxxxxxxxxx
12
 
1
# Root of the NSD webdataset. The cluster path is kept for reference:
# NAT_SCENE = "/scratch/gpfs/KNORMAN/webdataset_nsd/webdataset_split/"
NAT_SCENE = "/home/jimgoo/data/neuro/naturalscenesdataset/webdataset/"

# The tar files are named (and keyed) slightly differently in each location,
# so pick shard-name templates and the voxel key based on the root path.
if "/scratch/gpfs/KNORMAN" in NAT_SCENE:
    SUBJ_FORMAT, SUBJ_FORMAT_VAL, VOXELS_KEY = (
        "train_subj01_{{{}..{}}}.tar",
        "val_subj01_0.tar",
        'nsdgeneral.npy',
    )
else:
    SUBJ_FORMAT, SUBJ_FORMAT_VAL, VOXELS_KEY = (
        "subj01_nsdgeneral_{{{}..{}}}.tar",
        "val_subj01_nsdgeneral_0.tar",
        'voxel.npy',
    )
Metadata changed
xxxxxxxxxx
In [12]:
xxxxxxxxxx
1
 
1
SUBJ_FORMAT.format(0, 1)
Metadata changed
xxxxxxxxxx
'subj01_nsdgeneral_{0..1}.tar'
text/plain
In [13]:
xxxxxxxxxx
98
 
1
## Things in one sample of data:
#   sample00000.voxel.npy
#   sample00000.voxel_3d.npy
#   sample00000.trial.npy
#   sample00000.sgxl_emb.npy
#   sample00000.jpg

preproc_vox, preproc_img = get_preprocs()

# <TODO> check augmentation results before forward pass
# Image augmentation just for the CLIP image model that will be more
# semantic-focused. Richer pipeline to try later:
# img_augment = transforms.Compose([
#     transforms.RandomCrop(size=(140,140)),
#     transforms.Resize(size=(224,224)),
#     transforms.RandomHorizontalFlip(p=.5),
#     transforms.ColorJitter(.4,.4,.2,.1),
#     transforms.RandomGrayscale(p=.2),
# ])

# <TODO> try more things
img_augment = transforms.Compose([
                    transforms.Resize(size=(224,224)),
                ])


def _make_wds_loader(url):
    """Build the (voxels, embedding-input) webdataset pipeline + loader for `url`.

    Shared by train and validation so the two pipelines cannot drift apart
    (they were previously two copy-pasted blocks).
    NOTE(review): no shuffling stage — training currently iterates shards in
    order; confirm whether shuffling should be re-enabled.
    """
    data = wds.DataPipeline([
        # wds.ResampledShards(url), # <TODO> switch back to this once I understand it
        wds.SimpleShardList(url),
        wds.tarfile_to_samples(),
        # wds.shuffle(500, initial=500), # <TODO> this seems hardcoded for `full_training=False`
        wds.decode("torch"),
        # wds.rename(images="jpg;png", voxels=VOXELS_KEY, embs="sgxl_emb.npy", trial="trial.npy"),
        wds.rename(images="jpg;png", voxels=VOXELS_KEY), # <TODO> use less-lean version above
        wds.map_dict(images=preproc_img),
        wds.to_tuple("voxels", emb_name),
        wds.batched(batch_size, partial=True),
    ]) #.with_epoch(num_worker_batches) # <TODO> add this back
    loader = wds.WebLoader(data, num_workers=num_workers,
                           batch_size=None, shuffle=False, persistent_workers=True)
    return data, loader


if not full_training:
    # Quick-iteration settings: small batches, first two shards only.
    num_devices = 1
    num_workers = 4
    print("num_workers", num_workers)
    batch_size = 16
    print("batch_size", batch_size)
    num_samples = 500
    global_batch_size = batch_size * num_devices
    print("global_batch_size", global_batch_size)
    num_batches = math.floor(num_samples / global_batch_size)
    num_worker_batches = math.floor(num_batches / num_workers)
    print("num_worker_batches", num_worker_batches)
    train_url = os.path.join(NAT_SCENE, "train", SUBJ_FORMAT.format(0, 1))
else:
    # num_devices = torch.cuda.device_count()
    num_devices = 1
    print("WARNING: num_devices hardcoded")
    print("num_devices", num_devices)
    # num_workers = num_devices * 4
    num_workers = 1 # <TODO> switch back the above
    print("WARNING num_workers hardcoded")
    print("num_workers", num_workers)
    batch_size = 300
    print("batch_size",batch_size)
    num_samples = 24983 # see metadata.json in webdataset_split folder
    global_batch_size = batch_size * num_devices
    print("global_batch_size", global_batch_size)
    num_batches = math.floor(num_samples / global_batch_size)
    num_worker_batches = math.floor(num_batches / num_workers)
    print("num_worker_batches", num_worker_batches)
    train_url = os.path.join(NAT_SCENE, "train", SUBJ_FORMAT.format(0, 49))

train_data, train_dl = _make_wds_loader(train_url)

# Validation #
num_samples = 492
num_batches = math.ceil(num_samples / global_batch_size)
num_worker_batches = math.ceil(num_batches / num_workers)
print("validation: num_worker_batches", num_worker_batches)

url = os.path.join(NAT_SCENE, "val", SUBJ_FORMAT_VAL)
val_data, val_dl = _make_wds_loader(url)
Metadata changed
xxxxxxxxxx
WARNING: num_devices hardcoded
num_devices 1
WARNING num_workers hardcoded
num_workers 1
batch_size 300
global_batch_size 300
num_worker_batches 83
validation: num_worker_batches 2
application/vnd.jupyter.stdout
In [16]:
xxxxxxxxxx
18
 
1
def test_loader(dl):
    """Pull a single batch from `dl`, run it through the embedder, and
    return the embedding dimensionality.

    Relies on module globals: `emb_name`, `device`, `embedder`, and (for the
    text path) `text_tokenize` / `subj01_annots`.
    NOTE(review): on the text branch `text_tokens` is computed but `embedder`
    is still applied to `emb` — confirm this is intended.
    """
    for batch_idx, (voxel, emb) in enumerate(dl):
        print("idx", batch_idx)
        print("voxel.shape", voxel.shape)
        print("emb.shape", emb.shape)

        if emb_name == 'images':  # image embedding
            emb = emb.to(device)
        else:  # text embedding
            text_tokens = text_tokenize(subj01_annots[emb]).to(device)

        emb = embedder(emb)
        print("emb.shape2", emb.shape)
        out_dim = emb.shape[1]
        print("out_dim", out_dim)
        break
    return out_dim
Metadata changed
xxxxxxxxxx
In [17]:
xxxxxxxxxx
1
 
1
out_dim = test_loader(train_dl)
Metadata changed
xxxxxxxxxx
idx 0
voxel.shape torch.Size([300, 15724])
emb.shape torch.Size([300, 3, 224, 224])
emb.shape2 torch.Size([300, 768])
out_dim 768
application/vnd.jupyter.stdout
In [18]:
xxxxxxxxxx
1
 
1
out_dim = test_loader(val_dl)
Metadata changed
xxxxxxxxxx
idx 0
voxel.shape torch.Size([300, 15724])
emb.shape torch.Size([300, 3, 224, 224])
emb.shape2 torch.Size([300, 768])
out_dim 768
application/vnd.jupyter.stdout
In [19]:
xxxxxxxxxx
8
 
1
# t0 = time.time()
2
# n_batches = 0
3
# for train_i, (voxel0, emb0) in enumerate(train_dl):
4
#     n_batches += 1
5
# t1 = time.time()
6
​
7
# # 84, 233.06136536598206
8
# n_batches, t1-t0
Metadata changed
xxxxxxxxxx
In [20]:
xxxxxxxxxx
8
 
1
# t0 = time.time()
2
# n_batches = 0
3
# for val_i, (val_voxel0, val_emb0) in enumerate(val_dl):
4
#     n_batches += 1
5
# t1 = time.time()
6
​
7
# # (492, 3.9010021686553955)
8
# n_batches, t1-t0
Metadata changed
xxxxxxxxxx
In [21]:
xxxxxxxxxx
6
 
1
# get the first batch of everything
2
for train_i, (voxel0, emb0) in enumerate(train_dl):
3
    break
4
​
5
for val_i, (val_voxel0, val_emb0) in enumerate(val_dl):
6
    break
Metadata changed
xxxxxxxxxx
In [22]:
xxxxxxxxxx
1
 
1
voxel0.shape, val_voxel0.shape
Metadata changed
xxxxxxxxxx
(torch.Size([300, 15724]), torch.Size([300, 15724]))
text/plain
In [23]:
xxxxxxxxxx
1
 
1
emb0.shape, val_emb0.shape
Metadata changed
xxxxxxxxxx
(torch.Size([300, 3, 224, 224]), torch.Size([300, 3, 224, 224]))
text/plain
In [24]:
xxxxxxxxxx
1
 
1
torch_to_Image(emb0[0])
Metadata changed
xxxxxxxxxx
In [25]:
xxxxxxxxxx
1
 
1
torch_to_Image(val_emb0[0])
Metadata changed
xxxxxxxxxx
In [26]:
xxxxxxxxxx
6
 
1
# <TODO> scale the voxels once I understand more about the format of trials inside the tar dataset files
2
V = voxel0.cpu().numpy()
3
plt.plot(np.vstack((np.max(V, 0), np.mean(V, 0), np.min(V, 0))).T);
4
plt.legend(['max', 'mean', 'min']);
5
plt.xlabel('position in flattened voxel array');
6
plt.ylabel('voxel value');
Metadata changed
xxxxxxxxxx
In [36]:
xxxxxxxxxx
90
 
1
preproc_vox = transforms.Compose([transforms.ToTensor(),torch.nan_to_num])
2
​
3
preproc_img = transforms.Compose([
4
                    transforms.Resize(size=(224,224)),
5
                    transforms.Normalize(mean=mean,
6
                                         std=std),
7
                ])
8
​
9
# image augmentation just for the CLIP image model that will be more semantic-focused
10
img_augment = transforms.Compose([
11
                    transforms.RandomCrop(size=(140,140)),
12
                    transforms.Resize(size=(224,224)),
13
                    transforms.RandomHorizontalFlip(p=.5),
14
                    transforms.ColorJitter(.4,.4,.2,.1),
15
                    transforms.RandomGrayscale(p=.2),
16
                ])
17
​
18
if not full_training: 
19
    num_devices = 1
20
    num_workers = 4
21
    print("num_workers",num_workers)
22
    batch_size = 16
23
    print("batch_size",batch_size)
24
    num_samples = 500 
25
    global_batch_size = batch_size * num_devices
26
    print("global_batch_size",global_batch_size)
27
    num_batches = math.floor(num_samples / global_batch_size)
28
    num_worker_batches = math.floor(num_batches / num_workers)
29
    print("num_worker_batches",num_worker_batches)
30
    train_url = "/scratch/gpfs/KNORMAN/webdataset_nsd/webdataset_split/train/train_subj01_{0..1}.tar"
31
else:
32
    num_devices = torch.cuda.device_count()
33
    print("num_devices",num_devices)
34
    num_workers = num_devices * 4
35
    print("num_workers",num_workers)
36
    batch_size = 300
37
    print("batch_size",batch_size)
38
    num_samples = 24983 # see metadata.json in webdataset_split folder
39
    global_batch_size = batch_size * num_devices
40
    print("global_batch_size",global_batch_size)
41
    num_batches = math.floor(num_samples / global_batch_size)
42
    num_worker_batches = math.floor(num_batches / num_workers)
43
    print("num_worker_batches",num_worker_batches)
44
    train_url = "/scratch/gpfs/KNORMAN/webdataset_nsd/webdataset_split/train/train_subj01_{0..49}.tar"
45
​
46
train_data = wds.DataPipeline([wds.ResampledShards(train_url),
47
                    wds.tarfile_to_samples(),
48
                    wds.shuffle(500,initial=500),
49
                    wds.decode("torch"),
50
                    wds.rename(images="jpg;png", voxels="nsdgeneral.npy", embs="sgxl_emb.npy", trial="trial.npy"),
51
                    wds.map_dict(images=preproc_img),
52
                    wds.to_tuple("voxels", emb_name),
53
                    wds.batched(batch_size, partial=True),
54
                ]).with_epoch(num_worker_batches)
55
train_dl = wds.WebLoader(train_data, num_workers=num_workers,
56
                         batch_size=None, shuffle=False, persistent_workers=True)
57
​
58
# Validation #
59
num_samples = 492
60
num_batches = math.ceil(num_samples / global_batch_size)
61
num_worker_batches = math.ceil(num_batches / num_workers)
62
print("validation: num_worker_batches",num_worker_batches)
63
​
64
url = "/scratch/gpfs/KNORMAN/webdataset_nsd/webdataset_split/val/val_subj01_0.tar"
65
val_data = wds.DataPipeline([wds.ResampledShards(url),
66
                    wds.tarfile_to_samples(),
67
                    wds.decode("torch"),
68
                    wds.rename(images="jpg;png", voxels="nsdgeneral.npy", 
69
                                embs="sgxl_emb.npy", trial="trial.npy"),
70
                    wds.map_dict(images=preproc_img),
71
                    wds.to_tuple("voxels", emb_name),
72
                    wds.batched(batch_size, partial=True),
73
                ]).with_epoch(num_worker_batches)
74
val_dl = wds.WebLoader(val_data, num_workers=num_workers,
75
                       batch_size=None, shuffle=False, persistent_workers=True)
76
​
77
# check that your data loaders are working
78
for train_i, (voxel, emb) in enumerate(train_dl):
79
    print("idx",train_i)
80
    print("voxel.shape",voxel.shape)
81
    if emb_name=='images': # image embedding
82
        emb = emb.to(device)
83
    else: # text embedding
84
        text_tokens = text_tokenize(subj01_annots[emb]).to(device)
85
    print("emb.shape",emb.shape)
86
    emb = embedder(emb)
87
    print("emb.shape",emb.shape)
88
    out_dim = emb.shape[1]
89
    print("out_dim", out_dim)
90
    break
Metadata changed
xxxxxxxxxx
num_devices 1
num_workers 4
batch_size 300
global_batch_size 300
num_worker_batches 20
validation: num_worker_batches 1
idx 0
voxel.shape torch.Size([300, 15724])
emb.shape torch.Size([300, 3, 224, 224])
emb.shape torch.Size([300, 768])
out_dim 768
application/vnd.jupyter.stdout

Initialize network¶

In [6]:
In [30]:
xxxxxxxxxx
57
 
1
class BrainNetwork(nn.Module):
2
    def __init__(self, out_dim, h=7861):
3
        super().__init__()
4
        self.conv = nn.Sequential(
5
            nn.Conv1d(1, 32, kernel_size=3, stride=1, padding=0),
6
            nn.Dropout1d(0.1),
7
            nn.ReLU(),
8
            nn.MaxPool1d(kernel_size=2, stride=2)
9
        )
10
        self.lin = nn.Linear(h,h)
11
        self.relu = nn.ReLU()
12
        self.lin1 = nn.Linear(251552,out_dim)
13
        
14
    def forward(self, x):
15
        x = x[:,None,:]
16
        x = self.conv(x)
17
        residual = x
18
        for res_block in range(4):
19
            x = self.lin(x)
20
            x += residual
21
            x = self.relu(x)
22
            residual = x
23
        x = x.reshape(len(x),-1)
24
        x = self.lin1(x)
25
        return x
26
​
27
# PS note: i also tried the below network and it didn't work nearly as good at the top one
28
​
56
#         x = self.lin1(x)
57
#         return x 
⇛⇚
xxxxxxxxxx
95
 
1
# class BrainNetwork(nn.Module):
2
#     def __init__(self, out_dim, h=7861):
3
#         super().__init__()
4
#         self.conv = nn.Sequential(
5
#             nn.Conv1d(1, 32, kernel_size=3, stride=1, padding=0),
6
#             nn.Dropout1d(0.1),
7
#             nn.ReLU(),
8
#             nn.MaxPool1d(kernel_size=2, stride=2)
9
#         )
10
#         self.lin = nn.Linear(h, h)
11
#         self.relu = nn.ReLU()
12
#         self.lin1 = nn.Linear(251552, out_dim)
13
        
14
#     def forward(self, x):
15
#         #import ipdb; ipdb.set_trace()
16
#         # [300, 15724] -> [300, 1, 15724]
17
#         x = x[:, None, :]
18
        
19
#         # [300, 1, 15724] -> [300, 32, 7861]
20
#         x = self.conv(x)
21
#         residual = x
22
#         for res_block in range(4):
23
#             # same output shape
24
#             x = self.lin(x)
25
#             x += residual
26
#             x = self.relu(x)
27
#             residual = x
28
#         # [300, 32, 7861] -> [300, 251552]
29
#         x = x.reshape(len(x), -1)
30
#         x = self.lin1(x)
31
#         return x
32
​
33
# PS note: I also tried the network below and it didn't work nearly as well as the one above
34
​
62
#         x = self.lin1(x)
63
#         return x 
64
​
65
class BrainNetwork(nn.Module):
    """MLP that maps flattened fMRI voxels to an embedding vector.

    Four Linear layers with ReLU + Dropout between them; default widths
    taper 15724 -> 4096 -> 2048 -> 1024 -> `out_dim`.
    """

    def __init__(self,
                 out_dim,
                 input_size=15724,
                 h1=4096,
                 h2=2048,
                 h3=1024,
                 pdrop=0.1,
    ):
        super().__init__()

        # Build the layer stack as a list for readability, then wrap it.
        stack = [
            #torch.nn.BatchNorm1d(input_size),
            nn.Linear(input_size, h1),
            nn.ReLU(),
            nn.Dropout(pdrop),
            nn.Linear(h1, h2),
            nn.ReLU(),
            nn.Dropout(pdrop),
            nn.Linear(h2, h3),
            nn.ReLU(),
            nn.Dropout(pdrop),
            nn.Linear(h3, out_dim),
        ]
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # x: (batch, input_size) -> (batch, out_dim)
        return self.mlp(x)
92
​
93
def param_count(model):
    """Number of trainable (requires_grad) parameters in `model`."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
In [34]:
xxxxxxxxxx
39
 
1
# reset rng seed
2
torch.manual_seed(123)
3
np.random.seed(123)
4
​
5
# init model
6
brain_net = BrainNetwork(out_dim)
7
​
8
# input_size = 15724
9
# h1 = 4096
10
# h2 = 2048
11
# h3 = 1024
12
# pdrop = 0.1
13
​
14
# brain_net = nn.Sequential(
15
#     #torch.nn.BatchNorm1d(input_size),
16
#     nn.Linear(input_size, h1),
17
#     nn.ReLU(),
18
#     nn.Linear(h1, h2),
19
#     nn.ReLU(),
20
#     nn.Linear(h2, h3),
21
#     nn.ReLU(),
22
#     nn.Linear(h3, out_dim),
23
# )
24
# brain_net = nn.Sequential(
25
#     #torch.nn.BatchNorm1d(input_size),
26
#     nn.Linear(input_size, h1),
27
#     nn.ReLU(),
28
#     nn.Dropout(pdrop),
29
#     nn.Linear(h1, h2),
30
#     nn.ReLU(),
31
#     nn.Dropout(pdrop),
32
#     nn.Linear(h2, h3),
33
#     nn.ReLU(),
34
#     nn.Dropout(pdrop),
35
#     nn.Linear(h3, out_dim),
36
# )
37
​
38
print("{:,} params".format(param_count(brain_net)))
39
brain_net
Metadata changed
xxxxxxxxxx
75,685,632 params
application/vnd.jupyter.stdout
BrainNetwork(
  (mlp): Sequential(
    (0): Linear(in_features=15724, out_features=4096, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=4096, out_features=2048, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=2048, out_features=1024, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.1, inplace=False)
    (9): Linear(in_features=1024, out_features=768, bias=True)
  )
)
text/plain
In [7]:
In [35]:
xxxxxxxxxx
12
 
1
# reset rng seed
2
torch.manual_seed(123)
3
np.random.seed(123)
4
​
5
# init model
6
brain_net = BrainNetwork(out_dim)
7
brain_net = brain_net.to(device)
8
​
9
# test out that the neural network can run without error:
10
with torch.cuda.amp.autocast():
11
    out = brain_net(voxel.to(device))
12
    print(out.shape)
⇛⇚
xxxxxxxxxx
6
 
1
brain_net = brain_net.to(device)
2
​
3
# test out that the neural network can run without error:
4
with torch.cuda.amp.autocast():
5
    out = brain_net(voxel0.to(device))
6
    print(out.shape)
Outputs unchanged
torch.Size([300, 768])
application/vnd.jupyter.stdout

Train model¶

In [8]:
In [36]:
xxxxxxxxxx
24
 
1
if full_training:
2
    num_epochs = 100
3
else:
4
    num_epochs = 20
5
​
6
initial_learning_rate = 1e-6
7
optimizer = torch.optim.AdamW(brain_net.parameters(), lr=initial_learning_rate)
8
# optimizer = torch.optim.SGD(brain_net.parameters(), lr=initial_learning_rate, momentum=0.95)
9
​
10
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=1e-8, patience=5) 
11
nce = InfoNCE() # what we will use for loss function  
12
​
13
# Other losses to consider: #
14
​
⇛⇚
xxxxxxxxxx
33
 
1
if full_training:
2
    num_epochs = 100
3
else:
4
    num_epochs = 20
5
    
6
#initial_learning_rate = 1e-6
7
initial_learning_rate = 3e-4
8
# initial_learning_rate = 0.01
9
#initial_learning_rate = 3e-3
10
#print("WARNING - large learning rate", initial_learning_rate)
11
​
12
optimizer = torch.optim.Adam(brain_net.parameters(), lr=initial_learning_rate)
13
# optimizer = torch.optim.SGD(brain_net.parameters(), lr=initial_learning_rate, momentum=0.95)
14
​
15
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=1e-8, patience=5)
16
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
17
​
18
loss_fun = InfoNCE() # what we will use for loss function  
19
# loss_fun = nn.MSELoss()
20
#loss_fun = nn.MSELoss()
21
​
22
# Other losses to consider: #
23
​
In [37]:
xxxxxxxxxx
48
 
1
def plot_training():
    """Plot train/val loss and top-k performance curves side by side.

    Reads module globals: `num_epochs`, `batch_size`, `initial_learning_rate`,
    `train_losses`, `train_percent_correct`, `val_losses`, `val_percent_correct`.
    """
    print(f"num_epochs:{num_epochs} batch_size:{batch_size} lr:{initial_learning_rate}")
    fig, axes = plt.subplots(1, 4, figsize=(17,3))
    panels = [
        (train_losses, "Training Loss"),
        (train_percent_correct, "Training Performance"),
        (val_losses, "Val Loss"),
        (val_percent_correct, "Val Performance"),
    ]
    for ax, (series, label) in zip(axes, panels):
        ax.set_title(f"{label}\n(final={series[-1]})")
        ax.plot(series)
    plt.show()
13
    
14
def plot_preds(y, y_hat, title='', outdir=''):
15
    true = y.cpu().detach().numpy().T
16
    pred = y_hat.cpu().detach().numpy().T
17
    
18
    for i in range(y.shape[0]):
19
        plt.plot(np.vstack((true[:,i], pred[:,i])).T);
20
        plt.legend(['true', 'pred']);
21
        plt.title(title + ' sample %i' % i)
22
        if outdir:
23
            if not os.path.exists(outdir):
24
                os.makedirs(outdir)
25
            plt.savefig(outdir + '/%s-preds-sample-%i.jpeg' % (title, i))
26
            plt.close()
27
        else:
28
            plt.show()
29
        
30
def plot_err(y, y_hat, title=''):
    """Plot the elementwise residual (y - y_hat), one line per sample."""
    residual = (y - y_hat).cpu().detach().numpy()
    plt.plot(residual.T)
    plt.title(title)
    plt.show();
36
    
37
class AverageMeter:
    """Running (weighted) mean tracker for a stream of scalar values."""

    def __init__(self, name=None):
        # Optional label, handy when printing several meters side by side.
        self.name = name
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Fold in `val`, weighted by `n` observations."""
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
Metadata changed
xxxxxxxxxx
In [38]:
xxxxxxxxxx
1
 
1
outdir = './checkpoints/v01'
Metadata changed
xxxxxxxxxx
In [39]:
xxxxxxxxxx
1
 
1
!rm -rf $outdir/*
Metadata changed
xxxxxxxxxx
In [40]:
xxxxxxxxxx
3
 
1
!mkdir -p $outdir
2
!mkdir -p $outdir/preds/train/
3
!mkdir -p $outdir/preds/val/
Metadata changed
xxxxxxxxxx
In [41]:
xxxxxxxxxx
1
 
1
!tree $outdir/
Metadata changed
xxxxxxxxxx
./checkpoints/v01/
└── preds
    ├── train
    └── val

3 directories, 0 files
application/vnd.jupyter.stdout
In [9]:
In [42]:
xxxxxxxxxx
103
 
1
print(f"num_epochs:{num_epochs} batch_size:{batch_size} lr:{initial_learning_rate}")
2
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
3
print(f"Will be saving model checkpoints to checkpoints/{model_name}_subj01_epoch#.pth")
4
​
5
epoch = 0
6
train_losses = []; val_losses = []
7
train_percent_correct = []
8
val_percent_correct = []
9
lrs = []
10
​
11
# # resuming from checkpoint?
12
​
21
# lrs=checkpoint['lrs']
22
​
23
pbar = tqdm(range(epoch,num_epochs))
24
for epoch in pbar:
25
    brain_net.train()
26
    similarities = []
27
    for train_i, (voxel, emb) in enumerate(train_dl):
28
        optimizer.zero_grad()
29
        
30
        voxel = voxel.to(device)
31
        
32
        with torch.cuda.amp.autocast():
33
            if emb_name=='images': # image embedding
34
​
43
            if torch.any(torch.isnan(emb_)):
44
                raise ValueError("NaN found...")
45
            emb_ = nn.functional.normalize(emb_,dim=-1) # l2 normalization on the embeddings
46
            
47
            labels = torch.arange(len(emb)).to(device)
48
            loss = nce(emb_.reshape(len(emb),-1),emb.reshape(len(emb),-1))
49
            
50
            similarities = batchwise_cosine_similarity(emb,emb_)
51
​
52
            percent_correct = topk(similarities,labels,k=1)
53
            
54
        loss.backward()
55
        optimizer.step()
56
        
57
        train_losses.append(loss.item())
58
        train_percent_correct.append(percent_correct.item())
59
        
60
    brain_net.eval()    
61
    # using all validation samples to compute loss
62
    for val_i, (val_voxel, val_emb) in enumerate(val_dl):
63
        with torch.no_grad(): 
64
            val_voxel = val_voxel.to(device)
65
            
66
            with torch.cuda.amp.autocast():
67
                if emb_name=='images': # image embedding
68
​
72
​
73
                val_emb_ = brain_net(val_voxel)
74
            
75
                labels = torch.arange(len(val_emb)).to(device)
76
​
77
                val_loss = nce(val_emb_.reshape(len(val_emb),-1),val_emb.reshape(len(val_emb),-1))
78
​
79
                val_similarities = batchwise_cosine_similarity(val_emb,val_emb_)
80
​
81
                percent_correct = topk(val_similarities,labels,k=1)
82
                
83
            val_losses.append(val_loss.item())
84
            val_percent_correct.append(percent_correct.item())
85
                
86
    if epoch%5==0 and full_training:
87
        torch.save({
88
            'epoch': epoch,
89
            'model_state_dict': brain_net.state_dict(),
90
            'optimizer_state_dict': optimizer.state_dict(),
91
            'train_losses': train_losses,
92
            'val_losses': val_losses,
93
            'train_percent_correct': train_percent_correct,
94
            'val_percent_correct': val_percent_correct,
95
            'lrs': lrs,
96
            }, f'checkpoints/{model_name}_subj01_epoch{epoch}.pth')
97
            
98
    scheduler.step(val_loss) 
99
    lrs.append(optimizer.param_groups[0]['lr'])
100
    
101
    pbar.set_description(f"Loss: {np.median(train_losses[-(train_i+1):]):.3f} | VLoss: {np.median(val_losses[-(val_i+1):]):.3f}  | TopK%: {np.median(train_percent_correct[-10:]):.3f} | VTopK%: {np.median(val_percent_correct[-10:]):.3f} | lr{lrs[-1]:.5f}")
102
    
103
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
97
            
⇛⇚
xxxxxxxxxx
159
 
1
bs = 300
2
​
3
print(f"num_epochs:{num_epochs} batch_size:{batch_size} lr:{initial_learning_rate}")
4
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
5
print(f"Will be saving model checkpoints to checkpoints/{model_name}_subj01_epoch#.pth")
6
​
7
if not os.path.exists("checkpoints"):
8
    os.makedirs("checkpoints")
9
​
10
epoch = 0
11
train_losses = []; val_losses = []
12
train_percent_correct = []
13
val_percent_correct = []
14
lrs = []
15
epoch_logs = []
16
​
17
# # resuming from checkpoint?
18
​
27
# lrs=checkpoint['lrs']
28
​
29
pbar = tqdm(range(epoch, num_epochs))
30
for epoch in pbar:
31
    
32
    brain_net.train()
33
    
34
    train_loss_avg = AverageMeter()
35
    train_topk_avg = AverageMeter()
36
    val_loss_avg = AverageMeter()
37
    val_topk_avg = AverageMeter()
38
​
39
    for train_i, (voxel, emb) in enumerate(train_dl):
40
    #for train_i, (voxel, emb) in enumerate([(voxel0[:bs], emb0[:bs])]):
41
        # voxel = voxel0
42
        # emb = emb0
43
        
44
        bsz = voxel.shape[0]
45
​
46
        voxel = voxel.to(device)
47
​
48
        with torch.cuda.amp.autocast():
49
            if emb_name=='images': # image embedding
50
​
59
            if torch.any(torch.isnan(emb_)):
60
                raise ValueError("NaN found...")
61
            emb_ = nn.functional.normalize(emb_, dim=-1) # l2 normalization on the embeddings
62
​
63
            labels = torch.arange(bsz).to(device)
64
            loss = loss_fun(emb_.reshape(bsz, -1), emb.reshape(bsz, -1))
65
​
66
            similarities = batchwise_cosine_similarity(emb, emb_)
67
​
68
            percent_correct = topk(similarities, labels, k=1)
69
​
70
        optimizer.zero_grad()
71
        loss.backward()
72
        optimizer.step()
73
        
74
        train_losses.append(loss.item())
75
        train_percent_correct.append(percent_correct.item())
76
        
77
        train_loss_avg.update(loss.detach_(), bsz)
78
        train_topk_avg.update(percent_correct.detach_(), bsz)
79
        
80
        if train_i == 0 and epoch % 5 == 0:
81
            # plot_preds(emb[:2], emb_[:2], 'train', outdir + '/preds/epoch-%03d' % epoch)
82
            torch.save((emb_[:2], emb[:2]), outdir + '/preds/train/epoch-%03d.to' % epoch)
83
        
84
        # if train_i >= 0:
85
        #     break
86
​
87
    brain_net.eval()    
88
    
89
    # using all validation samples to compute loss
90
    for val_i, (val_voxel, val_emb) in enumerate(val_dl):
91
    #for val_i, (val_voxel, val_emb) in enumerate([(val_voxel0[:bs], val_emb0[:bs])]):
92
        # val_voxel = val_voxel0
93
        # val_emb = val_emb0
94
        
95
        bsz = val_voxel.shape[0]
96
​
97
        with torch.no_grad(): 
98
            val_voxel = val_voxel.to(device)
99
​
100
            with torch.cuda.amp.autocast():
101
                if emb_name=='images': # image embedding
102
​
106
​
107
                val_emb_ = brain_net(val_voxel)
108
                val_emb_ = nn.functional.normalize(val_emb_, dim=-1) # l2 normalization on the embeddings
109
​
110
                labels = torch.arange(bsz).to(device)
111
                val_loss = loss_fun(val_emb_.reshape(bsz,-1), val_emb.reshape(bsz,-1))
112
​
113
                val_similarities = batchwise_cosine_similarity(val_emb, val_emb_)
114
​
115
                percent_correct = topk(val_similarities, labels, k=1)
116
​
117
            val_losses.append(val_loss.item())
118
            val_percent_correct.append(percent_correct.item())
119
            
120
            val_loss_avg.update(val_loss.detach_(), bsz)
121
            val_topk_avg.update(percent_correct.detach_(), bsz)
122
            
123
            if val_i == 0 and epoch % 5 == 0:
124
                # plot_preds(val_emb[:2], val_emb_[:2], 'val', outdir + '/preds/epoch-%03d' % epoch)
125
                torch.save((val_emb_[:2], val_emb[:2]), outdir + '/preds/val/epoch-%03d.to' % epoch)
126
            
127
        # if val_i >= 0:
128
        #     break
129
​
130
#     if epoch % 5 == 0 and full_training:
131
#         torch.save({
132
#             'epoch': epoch,
133
#             'model_state_dict': brain_net.state_dict(),
134
#             'optimizer_state_dict': optimizer.state_dict(),
135
#             'train_losses': train_losses,
136
#             'val_losses': val_losses,
137
#             'train_percent_correct': train_percent_correct,
138
#             'val_percent_correct': val_percent_correct,
139
#             'lrs': lrs,
140
#             }, f'checkpoints/{model_name}_subj01_epoch{epoch}.pth')
141
    
142
    # <TODO> add back LR decay
143
    # scheduler.step(val_loss)
144
    
145
    lrs.append(optimizer.param_groups[0]['lr'])
146
    
147
    #pbar.set_description(f"Loss: {np.median(train_losses[-(train_i+1):]):.3f} | VLoss: {np.median(val_losses[-(val_i+1):]):.3f}  | TopK%: {np.median(train_percent_correct[-10:]):.3f} | VTopK%: {np.median(val_percent_correct[-10:]):.3f} | lr{lrs[-1]:.5f}")
148
    logs = OrderedDict(
149
        loss=train_loss_avg.avg.item(),
150
        topk=train_topk_avg.avg.item(),
151
        val_loss=val_loss_avg.avg.item(),
152
        val_topk=val_topk_avg.avg.item(),
153
        lr=lrs[-1],
154
    )
155
    epoch_logs.append(logs)
156
    pbar.set_postfix(**logs)
157
    pd.DataFrame(epoch_logs).to_csv(outdir + '/epoch-logs.csv')
158
    
159
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
Metadata changed
xxxxxxxxxx
4
 
1
{
2
  "code_folding": [
3
  ]
4
}
⇛⇚
xxxxxxxxxx
3
 
1
{
2
​
3
}
Outputs changed
num_epochs:100 batch_size:300 lr:1e-07
2022-10-30 17:44:55
Will be saving model checkpoints to checkpoints/clip_image_vit_subj01_epoch#.pth
num_epochs:100 batch_size:300 lr:0.0003
2022-11-16 17:07:34
Will be saving model checkpoints to checkpoints/clip_image_vit_subj01_epoch#.pth
application/vnd.jupyter.stdout
Output added
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [6:38:59<00:00, 239.39s/it, loss=0.447, lr=0.0003, topk=0.844, val_loss=3.32, val_topk=0.301]
application/vnd.jupyter.stderr
Output added
2022-11-16 23:46:34
application/vnd.jupyter.stdout
Output added
application/vnd.jupyter.stderr
Output deleted
Loss: 3.892 | VLoss: 3.818  | TopK%: 0.432 | VTopK%: 0.277 | lr0.00000:   9%|▉         | 6/64 [12:48<2:03:52, 128.15s/it]

KeyboardInterrupt

application/vnd.jupyter.stderr
In [49]:
xxxxxxxxxx
1
 
1
plot_training()
Metadata changed
xxxxxxxxxx
num_epochs:100 batch_size:300 lr:0.0003
application/vnd.jupyter.stdout
In [55]:
xxxxxxxxxx
1
 
1
pd.DataFrame(epoch_logs).plot(subplots=True);
Metadata changed
xxxxxxxxxx
In [53]:
xxxxxxxxxx
1
 
1
train_topk_avg.sum, train_topk_avg.count, train_topk_avg.avg
Metadata changed
xxxxxxxxxx
(tensor(21086., device='cuda:0'), 24983, tensor(0.8440, device='cuda:0'))
text/plain
In [54]:
xxxxxxxxxx
1
 
1
train_percent_correct[-1], len(train_percent_correct), sum(train_percent_correct)
Metadata changed
xxxxxxxxxx
(0.9036144018173218, 8400, 6814.349324496463)
text/plain
xxxxxxxxxx
1
 
1
## save model
Metadata changed
xxxxxxxxxx
In [59]:
xxxxxxxxxx
1
 
1
#!rm -rf checkpoints/*.pth
Metadata changed
xxxxxxxxxx
In [60]:
xxxxxxxxxx
1
 
1
model_name
Metadata changed
xxxxxxxxxx
'clip_image_vit'
text/plain
In [61]:
xxxxxxxxxx
1
 
1
epoch
Metadata changed
xxxxxxxxxx
99
text/plain
In [64]:
xxxxxxxxxx
1
 
1
ckpt_path = f'checkpoints/{model_name}_subj01_epoch{epoch}.pth'
Metadata changed
xxxxxxxxxx
In [65]:
xxxxxxxxxx
12
 
1
torch.save({
2
    'epoch': epoch,
3
    'model_state_dict': brain_net.state_dict(),
4
    'optimizer_state_dict': optimizer.state_dict(),
5
    'train_losses': train_losses,
6
    'val_losses': val_losses,
7
    'train_percent_correct': train_percent_correct,
8
    'val_percent_correct': val_percent_correct,
9
    'lrs': lrs,
10
    }, 
11
    ckpt_path
12
)
Metadata changed
xxxxxxxxxx
In [58]:
xxxxxxxxxx
11
 
1
print(f"num_epochs:{num_epochs} batch_size:{batch_size} lr:{initial_learning_rate}")
2
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(17,3))
3
ax1.set_title(f"Training Loss\n(final={train_losses[-1]})")
4
ax1.plot(train_losses)
5
ax2.set_title(f"Training Performance\n(final={train_percent_correct[-1]})")
6
ax2.plot(train_percent_correct)
7
ax3.set_title(f"Val Loss\n(final={val_losses[-1]})")
8
ax3.plot(val_losses)
9
ax4.set_title(f"Val Performance\n(final={val_percent_correct[-1]})")
10
ax4.plot(val_percent_correct)
11
plt.show()
Metadata changed
xxxxxxxxxx

Plot losses from saved model¶

In [73]:
In [66]:
xxxxxxxxxx
20
 
1
# Loading 
2
ckpt_path = 'checkpoints/clip_image_vit_subj01_epoch20.pth' 
3
checkpoint = torch.load(ckpt_path, map_location=device)
4
print(f"Plotting results from {ckpt_path}")
5
​
6
train_losses=checkpoint['train_losses']
7
train_percent_correct=checkpoint['train_percent_correct']
8
val_losses=checkpoint['val_losses']
9
val_percent_correct=checkpoint['val_percent_correct']
10
​
11
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(17,3))
12
ax1.set_title(f"Training Loss\n(final={train_losses[-1]})")
13
ax1.plot(train_losses)
14
ax2.set_title(f"Training Performance\n(final={train_percent_correct[-1]})")
15
ax2.plot(train_percent_correct)
16
ax3.set_title(f"Val Loss\n(final={val_losses[-1]})")
17
ax3.plot(val_losses)
18
ax4.set_title(f"Val Performance\n(final={val_percent_correct[-1]})")
19
ax4.plot(val_percent_correct)
20
plt.show()
⇛⇚
xxxxxxxxxx
22
 
1
def plot_saved(ckpt_path):
    """Load a training checkpoint and plot its stored loss/accuracy history.

    Args:
        ckpt_path: path to a .pth checkpoint dict containing the keys
            'train_losses', 'val_losses', 'train_percent_correct',
            and 'val_percent_correct'.
    """
    checkpoint = torch.load(ckpt_path, map_location=device)
    print(f"Plotting results from {ckpt_path}")

    # (panel title, series) pairs — replaces four copy-pasted plot blocks.
    panels = [
        ("Training Loss", checkpoint['train_losses']),
        ("Training Performance", checkpoint['train_percent_correct']),
        ("Val Loss", checkpoint['val_losses']),
        ("Val Performance", checkpoint['val_percent_correct']),
    ]
    fig, axes = plt.subplots(1, 4, figsize=(17, 3))
    for ax, (label, series) in zip(axes, panels):
        ax.set_title(f"{label}\n(final={series[-1]})")
        ax.plot(series)
    plt.show()
Outputs changed
Output deleted
Plotting results from checkpoints/clip_image_vit_subj01_epoch20.pth
application/vnd.jupyter.stdout
Output deleted
In [67]:
xxxxxxxxxx
1
 
1
plot_saved(ckpt_path)
Metadata changed
xxxxxxxxxx
Plotting results from checkpoints/clip_image_vit_subj01_epoch99.pth
application/vnd.jupyter.stdout

Evaluating Top-K Image Retrieval¶

Restart kernel, run "import packages & functions" and "initialize network" cells, and then run below cells.

In [70]:
xxxxxxxxxx
1
 
1
num_samples, batch_size, num_workers, num_worker_batches
Metadata changed
xxxxxxxxxx
(492, 300, 1, 2)
text/plain
In [90]:
xxxxxxxxxx
36
 
1
# num_samples = 492
2
# batch_size = 300
3
# num_batches = 1
4
# num_workers = 1
5
# num_worker_batches = 1
6
​
7
preproc_vox, preproc_img = get_preprocs()
8
​
9
# url = "/scratch/gpfs/KNORMAN/webdataset_nsd/webdataset_split/val/val_subj01_0.tar"
10
# val_data = wds.DataPipeline([wds.ResampledShards(url),
11
#                     wds.tarfile_to_samples(),
12
#                     wds.decode("torch"),
13
#                     wds.rename(images="jpg;png", voxels="nsdgeneral.npy", 
14
#                                 embs="sgxl_emb.npy", trial="trial.npy"),
15
#                     wds.map_dict(images=preproc_img),
16
#                     wds.to_tuple("voxels", "images", "trial"),
17
#                     wds.batched(batch_size, partial=True),
18
#                 ]).with_epoch(num_worker_batches)
19
# val_dl = wds.WebLoader(val_data, num_workers=num_workers,
20
#                        batch_size=None, shuffle=False, persistent_workers=True)
21
​
22
url = os.path.join(NAT_SCENE, "val", SUBJ_FORMAT_VAL)
23
​
24
val_data = wds.DataPipeline([
25
                    # wds.ResampledShards(url), # <TODO> switch back to this once I understand it
26
                    wds.SimpleShardList(url),
27
                    wds.tarfile_to_samples(),
28
                    wds.decode("torch"),
29
                    wds.rename(images="jpg;png", voxels=VOXELS_KEY, embs="sgxl_emb.npy", trial="trial.npy"),
30
                    wds.map_dict(images=preproc_img),
31
                    wds.to_tuple("voxels", emb_name, "trial"),
32
                    wds.batched(batch_size, partial=True),
33
                ]).with_epoch(1) #num_worker_batches)
34
​
35
val_dl = wds.WebLoader(val_data, num_workers=num_workers,
36
                       batch_size=None, shuffle=False, persistent_workers=True)
Metadata changed
xxxxxxxxxx
In [67]:
xxxxxxxxxx
25
 
1
num_samples = 492
2
batch_size = 300
3
num_batches = 1
4
num_workers = 1
5
num_worker_batches = 1
6
​
7
preproc_vox = transforms.Compose([transforms.ToTensor(),torch.nan_to_num])
8
preproc_img = transforms.Compose([
9
                    transforms.Resize(size=(224,224)),
10
                    transforms.Normalize(mean=mean,
11
                                         std=std),
12
                ])
13
​
14
url = "/scratch/gpfs/KNORMAN/webdataset_nsd/webdataset_split/val/val_subj01_0.tar"
15
val_data = wds.DataPipeline([wds.ResampledShards(url),
16
                    wds.tarfile_to_samples(),
17
                    wds.decode("torch"),
18
                    wds.rename(images="jpg;png", voxels="nsdgeneral.npy", 
19
                                embs="sgxl_emb.npy", trial="trial.npy"),
20
                    wds.map_dict(images=preproc_img),
21
                    wds.to_tuple("voxels", "images", "trial"),
22
                    wds.batched(batch_size, partial=True),
23
                ]).with_epoch(num_worker_batches)
24
val_dl = wds.WebLoader(val_data, num_workers=num_workers,
25
                       batch_size=None, shuffle=False, persistent_workers=True)
Metadata changed
xxxxxxxxxx
In [68]:
In [82]:
xxxxxxxxxx
41
 
1
clip_model, _ = clip.load("ViT-L/14", device=device)
2
resnet_model, _ = clip.load("RN50", device=device)
3
clip_model.eval()
4
resnet_model.eval()
5
​
6
f = h5py.File('/scratch/gpfs/KNORMAN/nsdgeneral_hdf5/COCO_73k_subj_indices.hdf5', 'r')
7
subj01_order = f['subj01'][:]
8
f.close()
9
​
10
# curated the COCO annotations in the same way as the mind_reader (Lin Sprague Singh) preprint
11
annots = np.load('/scratch/gpfs/KNORMAN/nsdgeneral_hdf5/COCO_73k_annots_curated.npy',allow_pickle=True)
12
subj01_annots = annots[subj01_order]
13
​
14
def text_tokenize(annots):
    """Pick one non-empty caption per sample and return CLIP tokens.

    For each element `b` of `annots`, a caption is drawn uniformly from
    b[0, 0] .. b[0, 4], resampling until a non-empty string comes up.

    Args:
        annots: iterable of per-sample caption arrays (indexed b[0, k]).

    Returns:
        The tensor produced by clip.tokenize over the chosen captions.
    """
    # Accumulate into a plain list instead of repeated np.vstack, which was
    # quadratic and raised NameError when `annots` was empty.
    captions = []
    for b in annots:
        t = ''
        while t == '':
            # resample until a non-empty caption is drawn
            idx = int(torch.randint(5, (1,)))
            t = b[0, idx]
        captions.append(str(t))
    return clip.tokenize(captions)
26
​
27
def clip_text_embedder(text_token):
    """Encode tokenized text with the frozen CLIP model (no gradients)."""
    with torch.no_grad():
        return clip_model.encode_text(text_token.to(device))
31
​
32
def clip_image_embedder(image):
    """Encode an image batch with frozen CLIP, clamping features to [-1.5, 1.5]."""
    with torch.no_grad():
        features = clip_model.encode_image(image.to(device))
        # clip extreme feature values to [-1.5, 1.5]
        features = torch.clamp(features, -1.5, 1.5)
    return features
37
​
38
def resnet_image_embedder(image):
    """Encode an image batch with the frozen CLIP RN50 model (no gradients)."""
    with torch.no_grad():
        return resnet_model.encode_image(image.to(device))
⇛⇚
xxxxxxxxxx
45
 
1
clip_model, _ = clip.load("ViT-L/14", device=device)
2
# resnet_model, _ = clip.load("RN50", device=device)
3
clip_model.eval()
4
# resnet_model.eval()
5
​
6
# f = h5py.File('/scratch/gpfs/KNORMAN/nsdgeneral_hdf5/COCO_73k_subj_indices.hdf5', 'r')
7
# subj01_order = f['subj01'][:]
8
# f.close()
9
​
10
# # curated the COCO annotations in the same way as the mind_reader (Lin Sprague Singh) preprint
11
# annots = np.load('/scratch/gpfs/KNORMAN/nsdgeneral_hdf5/COCO_73k_annots_curated.npy',allow_pickle=True)
12
# subj01_annots = annots[subj01_order]
13
​
14
# def text_tokenize(annots):
15
#     for i,b in enumerate(annots):
16
#         t = ''
17
#         while t == '':
18
#             rand = torch.randint(5,(1,1))[0][0]
19
#             t = b[0,rand]
20
#         if i==0:
21
#             txt = np.array(t)
22
#         else:
23
#             txt = np.vstack((txt,t))
24
#     txt = txt.flatten()
25
#     return clip.tokenize(txt)
26
​
27
# def clip_text_embedder(text_token):
28
#     with torch.no_grad():
29
#         text_features = clip_model.encode_text(text_token.to(device))
30
#     return text_features
31
​
32
# def clip_image_embedder(image):
33
#     with torch.no_grad():
34
#         image_features = clip_model.encode_image(image.to(device))
35
#         image_features = torch.clamp(image_features,-1.5,1.5) 
36
#     return image_features    
37
​
38
def clip_image_embedder(image):
    """Delegate image embedding to the notebook-level `embedder`.

    Guards that the active model is the CLIP ViT image variant before
    delegating; the failing assert reports the offending model_name.
    """
    assert model_name == 'clip_image_vit', model_name
    return embedder(image)
41
​
42
# def resnet_image_embedder(image):
43
#     with torch.no_grad():
44
#         image_features = resnet_model.encode_image(image.to(device))
45
#     return image_features
In [69]:
In [76]:
xxxxxxxxxx
17
 
1
brain_net = BrainNetwork(768) 
2
​
3
brain_net_clip_img = brain_net.to(device)
4
checkpoint = torch.load('checkpoints/clip_image_vit_subj01_epoch20.pth', map_location=device)
5
brain_net_clip_img.load_state_dict(checkpoint['model_state_dict'])
6
brain_net_clip_img.eval()
7
​
8
brain_net_clip_text = brain_net.to(device)
9
checkpoint = torch.load('checkpoints/clip_text_vit_subj01_epoch20.pth', map_location=device)
10
brain_net_clip_text.load_state_dict(checkpoint['model_state_dict'])
11
brain_net_clip_text.eval()
12
​
13
brain_net = BrainNetwork(1024) 
14
brain_net_resnet_img = brain_net.to(device)
15
checkpoint = torch.load('checkpoints/clip_image_resnet_subj01_epoch42.pth', map_location=device)
16
brain_net_resnet_img.load_state_dict(checkpoint['model_state_dict'])
17
brain_net_resnet_img.eval()
⇛⇚
xxxxxxxxxx
18
 
1
brain_net = BrainNetwork(768) 
2
​
3
brain_net_clip_img = brain_net.to(device)
4
# checkpoint = torch.load('checkpoints/clip_image_vit_subj01_epoch20.pth', map_location=device)
5
checkpoint = torch.load(ckpt_path, map_location=device)
6
brain_net_clip_img.load_state_dict(checkpoint['model_state_dict'])
7
brain_net_clip_img.eval()
8
​
9
# brain_net_clip_text = brain_net.to(device)
10
# checkpoint = torch.load('checkpoints/clip_text_vit_subj01_epoch20.pth', map_location=device)
11
# brain_net_clip_text.load_state_dict(checkpoint['model_state_dict'])
12
# brain_net_clip_text.eval()
13
​
14
# brain_net = BrainNetwork(1024) 
15
# brain_net_resnet_img = brain_net.to(device)
16
# checkpoint = torch.load('checkpoints/clip_image_resnet_subj01_epoch42.pth', map_location=device)
17
# brain_net_resnet_img.load_state_dict(checkpoint['model_state_dict'])
18
# brain_net_resnet_img.eval()
Outputs changed
Output added
BrainNetwork(
  (mlp): Sequential(
    (0): Linear(in_features=15724, out_features=4096, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=4096, out_features=2048, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=2048, out_features=1024, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.1, inplace=False)
    (9): Linear(in_features=1024, out_features=768, bias=True)
  )
)
text/plain
Output deleted
BrainNetwork(
  (conv): Sequential(
    (0): Conv1d(1, 32, kernel_size=(3,), stride=(1,))
    (1): Dropout1d(p=0.1, inplace=False)
    (2): ReLU()
    (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (lin): Linear(in_features=7861, out_features=7861, bias=True)
  (relu): ReLU()
  (lin1): Linear(in_features=251552, out_features=1024, bias=True)
)
text/plain
In [70]:
In [91]:
xxxxxxxxxx
105
 
1
​
4
        with torch.cuda.amp.autocast():
5
            voxel = voxel.to(device)
6
            embt = text_tokenize(subj01_annots[trial_idx]).to(device)
7
            
8
            emb0=[]; emb1 = []; emb2 =[]
9
            for m in np.arange(0,batch_size,minibatch):
10
                if m==0:
11
                    emb0 = clip_image_embedder(emb[m:m+minibatch]).detach().cpu()
12
                    emb1 = resnet_image_embedder(emb[m:m+minibatch]).detach().cpu()
13
                    emb2 = clip_text_embedder(embt[m:m+minibatch]).detach().cpu()
14
                else:
15
                    emb0 = torch.vstack((emb0,clip_image_embedder(emb[m:m+minibatch]).detach().cpu()))
16
                    emb1 = torch.vstack((emb1,resnet_image_embedder(emb[m:m+minibatch]).detach().cpu()))
17
                    emb2 = torch.vstack((emb2,clip_text_embedder(embt[m:m+minibatch]).detach().cpu()))
18
​
19
            emb0 = emb0.to(device)
20
            emb1 = emb1.to(device)
21
            emb2 = emb2.to(device)
22
            
23
            emb_0 = brain_net_clip_img(voxel)
24
            emb_1 = brain_net_resnet_img(voxel)
25
            emb_2 = brain_net_clip_text(voxel)
26
            
27
            labels = torch.arange(len(emb0)).to(device)
28
            similarities0 = batchwise_cosine_similarity(emb0,emb_0)
29
            similarities1 = batchwise_cosine_similarity(emb1,emb_1)
30
            similarities2 = batchwise_cosine_similarity(emb2,emb_2)
31
            
32
            # how to combine the different models?
33
            similaritiesx = similarities0/2+similarities1+similarities2/2
34
​
35
            print("CLIP IMG")
36
​
51
                plt.show()
52
                
53
            print("\nRESNET50 IMG")
54
            
55
            percent_correct = topk(similarities1,labels,k=1)
56
            print("percent_correct",percent_correct)
57
            
58
            similarities1=np.array(similarities1.detach().cpu())
59
            for trial in range(4):
60
                fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(11,6))
61
                ax[0].imshow(torch_to_Image(emb[trial]))
62
                ax[0].set_title("original\nimage")
63
                ax[0].axis("off")
64
                for attempt in range(5):
65
                    which = np.flip(np.argsort(similarities1[trial]))[attempt]
66
                    ax[attempt+1].imshow(torch_to_Image(emb[which]))
67
                    ax[attempt+1].set_title(f"Top {attempt}")
68
                    ax[attempt+1].axis("off")
69
                plt.show()
70
                
71
            print("\nCLIP TEXT")
72
            
73
            percent_correct = topk(similarities2,labels,k=1)
74
            print("percent_correct",percent_correct)
75
            
76
            similarities2=np.array(similarities2.detach().cpu())
77
            for trial in range(4):
78
                fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(11,6))
79
                ax[0].imshow(torch_to_Image(emb[trial]))
80
                ax[0].set_title("original\nimage")
81
                ax[0].axis("off")
82
                for attempt in range(5):
83
                    which = np.flip(np.argsort(similarities2[trial]))[attempt]
84
                    ax[attempt+1].imshow(torch_to_Image(emb[which]))
85
                    ax[attempt+1].set_title(f"Top {attempt}")
86
                    ax[attempt+1].axis("off")
87
                plt.show()
88
                
89
            print("\nCOMBINED")
90
            
91
            percent_correct = topk(similaritiesx,labels,k=1)
92
            print("percent_correct",percent_correct)
93
            
94
            similaritiesx=np.array(similaritiesx.detach().cpu())
95
            for trial in range(4):
96
                fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(11,6))
97
                ax[0].imshow(torch_to_Image(emb[trial]))
98
                ax[0].set_title("original\nimage")
99
                ax[0].axis("off")
100
                for attempt in range(5):
101
                    which = np.flip(np.argsort(similaritiesx[trial]))[attempt]
102
                    ax[attempt+1].imshow(torch_to_Image(emb[which]))
103
                    ax[attempt+1].set_title(f"Top {attempt}")
104
                    ax[attempt+1].axis("off")
105
                plt.show()
⇛⇚
xxxxxxxxxx
106
 
1
​
4
        with torch.cuda.amp.autocast():
5
            voxel = voxel.to(device)
6
            # embt = text_tokenize(subj01_annots[trial_idx]).to(device)
7
            
8
            emb0=[]; emb1 = []; emb2 =[]
9
            for m in np.arange(0,batch_size,minibatch):
10
                if m==0:
11
                    emb0 = clip_image_embedder(emb[m:m+minibatch]).detach().cpu()
12
                    # emb1 = resnet_image_embedder(emb[m:m+minibatch]).detach().cpu()
13
                    # emb2 = clip_text_embedder(embt[m:m+minibatch]).detach().cpu()
14
                else:
15
                    emb0 = torch.vstack((emb0,clip_image_embedder(emb[m:m+minibatch]).detach().cpu()))
16
                    # emb1 = torch.vstack((emb1,resnet_image_embedder(emb[m:m+minibatch]).detach().cpu()))
17
                    # emb2 = torch.vstack((emb2,clip_text_embedder(embt[m:m+minibatch]).detach().cpu()))
18
​
19
            emb0 = emb0.to(device)
20
            # emb1 = emb1.to(device)
21
            # emb2 = emb2.to(device)
22
            
23
            emb_0 = brain_net_clip_img(voxel)
24
            emb_0 = nn.functional.normalize(emb_0, dim=-1) # <TODO> move into network
25
            # emb_1 = brain_net_resnet_img(voxel)
26
            # emb_2 = brain_net_clip_text(voxel)
27
            
28
            labels = torch.arange(len(emb0)).to(device)
29
            similarities0 = batchwise_cosine_similarity(emb0,emb_0)
30
            # similarities1 = batchwise_cosine_similarity(emb1,emb_1)
31
            # similarities2 = batchwise_cosine_similarity(emb2,emb_2)
32
            
33
            # how to combine the different models?
34
             #similaritiesx = similarities0/2+similarities1+similarities2/2
35
​
36
            print("CLIP IMG")
37
​
52
                plt.show()
53
                
54
#             print("\nRESNET50 IMG")
55
            
56
#             percent_correct = topk(similarities1,labels,k=1)
57
#             print("percent_correct",percent_correct)
58
            
59
#             similarities1=np.array(similarities1.detach().cpu())
60
#             for trial in range(4):
61
#                 fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(11,6))
62
#                 ax[0].imshow(torch_to_Image(emb[trial]))
63
#                 ax[0].set_title("original\nimage")
64
#                 ax[0].axis("off")
65
#                 for attempt in range(5):
66
#                     which = np.flip(np.argsort(similarities1[trial]))[attempt]
67
#                     ax[attempt+1].imshow(torch_to_Image(emb[which]))
68
#                     ax[attempt+1].set_title(f"Top {attempt}")
69
#                     ax[attempt+1].axis("off")
70
#                 plt.show()
71
                
72
#             print("\nCLIP TEXT")
73
            
74
#             percent_correct = topk(similarities2,labels,k=1)
75
#             print("percent_correct",percent_correct)
76
            
77
#             similarities2=np.array(similarities2.detach().cpu())
78
#             for trial in range(4):
79
#                 fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(11,6))
80
#                 ax[0].imshow(torch_to_Image(emb[trial]))
81
#                 ax[0].set_title("original\nimage")
82
#                 ax[0].axis("off")
83
#                 for attempt in range(5):
84
#                     which = np.flip(np.argsort(similarities2[trial]))[attempt]
85
#                     ax[attempt+1].imshow(torch_to_Image(emb[which]))
86
#                     ax[attempt+1].set_title(f"Top {attempt}")
87
#                     ax[attempt+1].axis("off")
88
#                 plt.show()
89
                
90
#             print("\nCOMBINED")
91
            
92
#             percent_correct = topk(similaritiesx,labels,k=1)
93
#             print("percent_correct",percent_correct)
94
            
95
#             similaritiesx=np.array(similaritiesx.detach().cpu())
96
#             for trial in range(4):
97
#                 fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(11,6))
98
#                 ax[0].imshow(torch_to_Image(emb[trial]))
99
#                 ax[0].set_title("original\nimage")
100
#                 ax[0].axis("off")
101
#                 for attempt in range(5):
102
#                     which = np.flip(np.argsort(similaritiesx[trial]))[attempt]
103
#                     ax[attempt+1].imshow(torch_to_Image(emb[which]))
104
#                     ax[attempt+1].set_title(f"Top {attempt}")
105
#                     ax[attempt+1].axis("off")
106
#                 plt.show()
Outputs changed
CLIP IMG
percent_correct tensor(0.1333, device='cuda:0')
CLIP IMG
percent_correct tensor(0.3133, device='cuda:0')
application/vnd.jupyter.stdout
Output added
Output added
Output added
Output added
Output deleted
Output deleted
Output deleted
Output deleted
Output deleted
RESNET50 IMG
percent_correct tensor(0.3567, device='cuda:0')
application/vnd.jupyter.stdout
Output deleted
Output deleted
Output deleted
Output deleted
Output deleted
CLIP TEXT
percent_correct tensor(0.1633, device='cuda:0')
application/vnd.jupyter.stdout
Output deleted
Output deleted
Output deleted
Output deleted
Output deleted
COMBINED
percent_correct tensor(0.3833, device='cuda:0')
application/vnd.jupyter.stdout
Output deleted
Output deleted
Output deleted
Output deleted
In [ ]:
xxxxxxxxxx